package com.easy;

import com.easy.domain.ClassA;
import com.easy.domain.ClassB;
import com.easy.proxy.JdkProxy;
import org.springframework.beans.factory.FactoryBean;
import org.springframework.beans.factory.ObjectFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.beans.factory.support.RootBeanDefinition;

import java.lang.reflect.Field;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

/**
 * @ClassName CircularDependenciesTest
 * @Description Minimal model of a three-level singleton cache (singletonObjects /
 * earlySingletonObjects / singletonFactories) used to resolve circular dependencies between
 * {@code @Autowired} fields, with dependencies resolved by field name.
 * @Author zheng
 * @Date 2023/4/25 10:19
 * @Version 1.0
 **/
public class CircularDependenciesTest {
    /**
     * Bean definitions, keyed by bean name.
     */
    private static final Map<String, BeanDefinition> beanDefinitionMap = new ConcurrentHashMap<>();

    /**
     * First-level cache: fully created singletons.
     * <p>
     * Without concurrency, this cache alone could resolve circular dependencies. Under concurrency,
     * however, it would then have to hold incomplete objects, so getBean could return a partially
     * built bean.
     **/
    private static final Map<String, Object> singletonObjects = new ConcurrentHashMap<>();
    /**
     * Second-level cache: early references to beans that are still being created. Keeping them here
     * instead of in singletonObjects prevents getBean from returning an incomplete object.
     **/
    private static final Map<String, Object> earlySingletonObjects = new ConcurrentHashMap<>();

    /**
     * Names of beans whose creation is currently in progress.
     */
    private static final Set<String> currentlyCreate = ConcurrentHashMap.newKeySet();
    /**
     * Third-level cache: factories that produce the early reference of a bean in creation.
     */
    private static final Map<String, ObjectFactory<?>> singletonFactories = new ConcurrentHashMap<>();


    public static void main(String[] args) throws Exception {
        loadBeanDefinitionMap();
        // Eagerly create every registered singleton.
        for (String beanName : beanDefinitionMap.keySet()) {
            getBean(beanName);
        }
        ClassA classA = (ClassA) getBean("classA");
        classA.test();
        System.out.println();
    }

    /**
     * @return java.lang.Object the singleton registered under {@code beanName}
     * @Description Gets a bean, creating it (and its dependencies) on first access.
     * @Date 2023/4/25 11:52
     * @Param [beanName bean name]
     **/
    private static Object getBean(String beanName) throws Exception {
        Object singleton = getSingleton(beanName);
        if (singleton != null) {
            return singleton;
        }
        return getSingleton(beanName, () -> {
            try {
                return createBean(beanName);
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        });
    }


    /**
     * Instantiates the bean named {@code beanName} and injects its {@code @Autowired} fields.
     * <p>
     * Before injection, an early-reference factory is registered in {@link #singletonFactories} so
     * that a bean depending back on this one can obtain a reference while this one is incomplete.
     * Registering the finished bean in {@link #singletonObjects} and clearing the in-creation mark
     * are left to {@link #getSingleton(String, ObjectFactory)}.
     *
     * @param beanName name of a registered bean definition
     * @return the raw (not post-processed) bean instance
     * @throws IllegalArgumentException if no definition is registered under {@code beanName}
     * @throws Exception                if instantiation or field injection fails
     */
    @SuppressWarnings("unchecked")
    public static <T> T createBean(String beanName) throws Exception {
        // Look up the definition first so an unknown name fails fast without leaving stale state.
        RootBeanDefinition beanDefinition = (RootBeanDefinition) beanDefinitionMap.get(beanName);
        if (beanDefinition == null) {
            throw new IllegalArgumentException("No bean definition registered for name '" + beanName + "'");
        }
        Class<?> beanClass = beanDefinition.getBeanClass();
        // Mark as in creation so getSingleton(String) may hand out the early reference.
        currentlyCreate.add(beanName);
        // Instantiate via the no-arg constructor (Class#newInstance is deprecated).
        Object singleton = beanClass.getDeclaredConstructor().newInstance();
        // Third-level cache entry: produces the (possibly post-processed) early reference on demand.
        singletonFactories.put(beanName,
                () -> new MyBeanPostProcessor().getEarlyBeanReference(singleton, beanName));

        // Inject every @Autowired field declared on the bean class; the dependency is the bean whose
        // name equals the field name.
        for (Field field : beanClass.getDeclaredFields()) {
            if (field.getAnnotation(Autowired.class) == null) {
                continue;
            }
            // Needed for private fields; otherwise Field#set throws IllegalAccessException.
            field.setAccessible(true);
            Object dependency = getBean(field.getName());
            field.set(singleton, dependency);
        }
        return (T) singleton;
    }


    /**
     * Returns the singleton for {@code beanName}, creating it through {@code singletonFactory} if
     * it is not yet in the first-level cache.
     * <p>
     * After creation, if an early reference exists (or can be produced) for the bean, that
     * reference is what other beans were given, so it is the one registered and returned. The
     * second/third-level entries and the in-creation mark are always cleared, including on failure.
     *
     * @param beanName         bean name
     * @param singletonFactory creates the raw bean (createBean when called from getBean)
     * @return the registered singleton
     */
    @SuppressWarnings("unchecked")
    private static <T> T getSingleton(String beanName, ObjectFactory<?> singletonFactory) {
        Object existing = singletonObjects.get(beanName);
        if (existing != null) {
            // Already fully created: never build a second instance.
            return (T) existing;
        }
        try {
            // Note: when called from getBean, this is createBean.
            Object rawBean = singletonFactory.getObject();
            Object earlyReference = currentlyCreate.contains(beanName) ? getSingleton(beanName) : null;
            Object exposed = earlyReference != null ? earlyReference : rawBean;
            // Complete bean: put it into the first-level cache (second/third-level cleared below).
            singletonObjects.put(beanName, exposed);
            return (T) exposed;
        } finally {
            earlySingletonObjects.remove(beanName);
            singletonFactories.remove(beanName);
            currentlyCreate.remove(beanName);
        }
    }

    /**
     * Looks up {@code beanName} in the three caches without creating it.
     * <p>
     * Order: the first-level cache; then, only if the bean is currently in creation, the
     * second-level cache; finally its third-level factory. A factory's result is promoted to the
     * second-level cache, and the factory is removed.
     *
     * @param beanName bean name
     * @return the complete or early singleton, or {@code null} if none is available
     */
    private static Object getSingleton(String beanName) {
        Object singleton = singletonObjects.get(beanName);
        if (singleton != null || !currentlyCreate.contains(beanName)) {
            return singleton;
        }
        singleton = earlySingletonObjects.get(beanName);
        if (singleton == null) {
            ObjectFactory<?> objectFactory = singletonFactories.get(beanName);
            // The factory may be absent, e.g. if instantiation failed before it was registered.
            if (objectFactory != null) {
                singleton = objectFactory.getObject();
                singletonFactories.remove(beanName);
                if (singleton != null) {
                    earlySingletonObjects.put(beanName, singleton);
                }
            }
        }
        return singleton;
    }

    /**
     * @return void
     * @Description Simulates populating beanDefinitionMap with the demo bean definitions.
     * @Date 2023/4/25 11:36
     * @Param []
     **/
    private static void loadBeanDefinitionMap() {
        RootBeanDefinition aRbd = new RootBeanDefinition(ClassA.class);
        RootBeanDefinition bRbd = new RootBeanDefinition(ClassB.class);
        beanDefinitionMap.put("classA", aRbd);
        beanDefinitionMap.put("classB", bRbd);
    }
}
